import Optimizer
import random

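# this optimizer performs a variable neighborhood search over states encoded as bit strings
# each batch samples unexplored neighbors of the current state (neighbors differ by k flipped bits)
# when a batch yields no improvement, the neighborhood size k is expanded to escape the local maximum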
class VariableNeighborSearch(Optimizer.Optimizer):
    """Variable neighborhood search over states encoded as num_bits-bit integers.

    The k-neighborhood of a state is the set of states obtained by flipping
    exactly k of its bits.  Each batch is filled with not-yet-explored
    neighbors of the current state.  When the previous batch contained a
    state with a better value, the search moves there and restarts with
    k = 1; otherwise it jumps to k = expandedNeighborhoodSize.  Whenever a
    neighborhood cannot fill the batch, k is increased by one.

    Constructor arguments:
      expandedNeighborhoodSize: k used after a batch without improvement
      batch_size: how many states to propose per batch
      num_batches: how many batches to produce before isFinished() is True
      num_bits: number of bits in a state; states range over [0, 2**num_bits - 1]
      number_top_states, characterizer: forwarded unchanged to Optimizer.Optimizer
      initial_state: optional starting state; when None the first batch is random
      verbose: print progress messages when True

    NOTE(review): this class reads self.explored_states, which is not set
    here; it is presumably a dict mapping state -> value maintained by
    Optimizer.Optimizer -- confirm against the base class.
    """

    def __init__(self, expandedNeighborhoodSize, batch_size, num_batches, num_bits, number_top_states, characterizer, initial_state = None, verbose=False):
        super(VariableNeighborSearch,self).__init__(number_top_states,characterizer,verbose)
        self.verbose = verbose

        self.num_batches = num_batches

        self.num_bits = num_bits
        # largest representable state: all num_bits bits set
        self.maximum_state =  2**num_bits-1

        self.batch_counter = 0
        self.current_state = initial_state
        # -inf so that any evaluated state counts as an improvement at first
        self.current_value = float("-inf")
        self.next_states = []

        self.batch_size = batch_size

        self.expandedNeighborhoodSize = expandedNeighborhoodSize


############### INTERFACE FUNCTIONS ###############

    def isFinished(self):
        """Return True once num_batches batches have been handed out."""
        if self.verbose:
            print("Checking if Finished")
        return self.batch_counter >= self.num_batches

    def getNextStates(self):
        """Return the next batch of states to evaluate (a list of ints).

        The first call returns the initial batch.  Later calls look up the
        values of the previous batch in self.explored_states, move to the
        best state if it improves on the current value, and then sample
        unexplored neighbors of the current state.  The returned list may be
        shorter than batch_size (even empty) when not enough unexplored
        states remain.
        """
        if self.verbose:
            print("Getting Next States")
        if self.batch_counter == 0:
            if self.verbose:
                print("Selecting Initial States...")
            self.next_states = self.getInitialStates()
            self.batch_counter += 1
            return self.next_states

        if self.verbose:
            print("Comparing Choices...")
        best_neighbor_state = None
        best_neighbor_value = float("-inf")
        for neighbor_state in self.next_states:
            neighbor_value = self.explored_states[neighbor_state]
            if neighbor_value > best_neighbor_value:
                best_neighbor_state = neighbor_state
                best_neighbor_value = neighbor_value

        if self.verbose:
            print("Checking for Local Maximum...")
        if self.current_value < best_neighbor_value:
            if self.verbose:
                print("Improvement Detected: Moving State")
            self.current_state = best_neighbor_state
            self.current_value = best_neighbor_value
            neighborhoodSize = 1
        elif self.current_state is None:
            # The random initial batch produced nothing better than -inf, so
            # there is no state to search around yet; draw a fresh random
            # batch instead of failing on bin(None).
            if self.verbose:
                print("No Usable State Yet: Selecting Random States")
            self.next_states = self._randomStates()
            self.batch_counter += 1
            return self.next_states
        else:
            if self.verbose:
                print("Local Maximum Detected: Expanding Neighborhood")
            neighborhoodSize = self.expandedNeighborhoodSize

        self.next_states = self._collectUnexploredNeighbors(
            self.current_state, neighborhoodSize, self.batch_size)
        self.batch_counter += 1
        return self.next_states

############### Child Class Helper Functions ###############

    def getInitialStates(self):
        """Return the first batch.

        Without an initial state this is batch_size uniformly random states
        (duplicates are possible).  With an initial state it is that state
        followed by unexplored neighbors, starting from the 1-neighborhood.
        """
        if self.current_state is None:
            return self._randomStates()
        neighbors = self._collectUnexploredNeighbors(
            self.current_state, 1, self.batch_size - 1)
        return [self.current_state] + neighbors

    # draw batch_size states uniformly at random from [0, maximum_state]
    def _randomStates(self):
        return [random.randint(0, self.maximum_state) for _ in range(self.batch_size)]

    # gather up to `count` unexplored neighbors of center_state, starting at the
    # given neighborhood size and widening it whenever a neighborhood runs dry
    def _collectUnexploredNeighbors(self, center_state, neighborhoodSize, count):
        selected = []
        remaining = count
        # A state has no neighbors further than num_bits flips away, so stop
        # there; without this bound the loop would spin forever once every
        # reachable state has been explored.
        while remaining > 0 and neighborhoodSize <= self.num_bits:
            if self.verbose:
                print("Selecting Neighbors from k Neighborhood")
            neighborhood = self.getNeighborhood(center_state, neighborhoodSize)
            # membership is tested on the mapping itself rather than on .keys(),
            # which would build a list per test on Python 2
            candidates = [neighbor for neighbor in neighborhood
                          if neighbor not in self.explored_states]
            if len(candidates) > remaining:
                selected += random.sample(candidates, remaining)
            else:
                # neighborhood exhausted: take everything and widen the search
                selected += candidates
                neighborhoodSize += 1
            remaining = count - len(selected)
            if remaining > 0 and self.verbose:
                print("Neighborhood explored: Adding extra expansion")
        return selected

    def state2bin(self,state):
        """Return state as a '0b'-prefixed binary string padded to num_bits digits."""
        bin_state = bin(state)
        return bin_state[:2] + bin_state[2:].zfill(self.num_bits)

    def bin2state(self,bin_state):
        """Return the integer encoded by a binary string."""
        return int(bin_state,2)

    def binNeighborhood2Neighborhood(self,binNeighborhood):
        """Convert an iterable of binary strings into a list of integer states."""
        return [self.bin2state(bin_state) for bin_state in binNeighborhood]

    # flip the bit in the state indicated by the bit_idx
    # bit_idx 0 is the first digit after the '0b' prefix (most significant bit)
    def flipBit(self, bin_state, bit_idx):
        position = bit_idx + 2  # skip the '0b' prefix
        flipped = '1' if bin_state[position] == '0' else '0'
        return bin_state[:position] + flipped + bin_state[position+1:]

    def getBinNeighborhood(self, bin_state, start_idx, num_to_flip):
        """Return every binary string reachable from bin_state by flipping
        exactly num_to_flip distinct bits chosen at index start_idx or later.

        Returns [bin_state] when num_to_flip is 0 and [] when fewer than
        num_to_flip bits remain from start_idx onward.
        """
        if num_to_flip == 0:
            return [bin_state]
        neighborhood = []
        for bit_idx in range(start_idx, self.num_bits):
            flipped_state = self.flipBit(bin_state,bit_idx)
            # later flips only touch higher indices, so each bit set is visited once
            neighborhood += self.getBinNeighborhood(flipped_state, bit_idx+1, num_to_flip-1)
        return neighborhood

    def getNeighborhood(self, state, number_of_steps):
        """Return the integer states that differ from state in exactly number_of_steps bits."""
        bin_state = self.state2bin(state)
        bin_neighborhood = self.getBinNeighborhood(bin_state,0,number_of_steps)
        return self.binNeighborhood2Neighborhood(bin_neighborhood)
